Conversation
Pin transformers to 4.x so diffusers 0.32.2 can still import FLAX_WEIGHTS_NAME. Patch the rocmProfileData rocpd_python Makefile to avoid pip install --user inside the base image venv. Install flash-attention with FLASH_ATTENTION_TRITON_AMD_ENABLE for headless/CI-friendly builds, replacing the prior wheel build from a pinned SHA.
Pull request overview
This PR updates AMD PyTorch Docker images to make flash-attention build reliably in headless CI and to prevent a runtime incompatibility between diffusers==0.32.2 and transformers v5 in the pyt_hy_video image.
Changes:
- Pin `transformers` to `<5` alongside `diffusers==0.32.2` in `pyt_hy_video` to avoid an ImportError at runtime.
- Switch the `pyt_hy_video` flash-attention installation to the ROCm Triton path and patch `rocmProfileData`’s `rocpd_python/Makefile` to avoid `pip install --user` inside a venv.
- Add `gfx950` to the ROCm arch list in `pyt_mochi_inference`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| docker/pyt_mochi_inference.ubuntu.amd.Dockerfile | Extends ROCm arch list used for flash-attention wheel build. |
| docker/pyt_hy_video.ubuntu.amd.Dockerfile | Pins transformers to 4.x, changes flash-attention install approach, and patches rocmProfileData install behavior for venv compatibility. |
Comments suppressed due to low confidence (1)
docker/pyt_mochi_inference.ubuntu.amd.Dockerfile:40
`PYTORCH_ROCM_ARCH` is a semicolon-delimited string, but it’s expanded unquoted in the later `GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | ...)` command substitution. In `/bin/sh`, the `;` characters will be treated as command separators, which can break the build. Quote the expansion (e.g., use `echo "${PYTORCH_ROCM_ARCH}"`) or switch to a delimiter that won’t be parsed by the shell and adapt the `sed` accordingly.
```dockerfile
ARG PYTORCH_ROCM_ARCH=gfx950;gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
RUN git clone ${FA_REPO}
RUN cd flash-attention \
    && git submodule update --init \
    && GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist \
```
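A minimal sketch of the quoted form the comment suggests, assuming the remaining build steps mirror the excerpt above (the trailing pip install line is inferred from the later hunk, not copied from this file):

```dockerfile
# Quoting keeps the semicolon-delimited arch list as a single argument to echo,
# so the shell never interprets the ';' characters as command separators.
RUN cd flash-attention \
    && git submodule update --init \
    && GPU_ARCHS=$(echo "${PYTORCH_ROCM_ARCH}" | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
```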
Comment on lines +25 to +39
```dockerfile
#ARG FA_SHA="b3ae4966b2567811880db10d9e040a775b99c7d7"
#ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
#ARG FA_GPU_ARCH=gfx942
#RUN git clone ${FA_REPO} && \
#    cd flash-attention && \
#    git checkout ${FA_SHA} && \
#    git submodule update --init && \
#    F='${FA_GPU_ARCH}' && \
#    if [ -z "$F" ]; then F=gfx942; fi && \
#    if [ "$F" = "native" ]; then F=gfx942; fi && \
#    GPU_ARCHS="$F" python3 setup.py bdist_wheel --dist-dir=dist && \
#    pip install dist/*.whl;
RUN git clone https://github.com/ROCm/flash-attention.git \
    && cd flash-attention \
    && FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```
Collaborator
@vadseshu let us pin FA to a working commit as suggested.
```dockerfile
    pip install dist/*.whl;
# flash attn (avoid ARG name PYTORCH_ROCM_ARCH: base image ENV can shadow it and expand to "")
# ROCm flash-attention: FA_GPU_ARCH=native needs a visible GPU at compile time and fails in CI/docker.
# Coerce native/empty to gfx942 for headless CI; for MI350 pass --build-arg FA_GPU_ARCH=gfx950 (needs FA_SHA with gfx950 in setup.py).
```
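One way to act on the collaborator's request to pin FA while keeping the Triton install path, sketched with the SHA from the commented-out block above as a placeholder (a known-good commit would need to be verified first):

```dockerfile
# Hypothetical sketch: pin flash-attention to an explicit commit, then install via the Triton AMD backend.
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG FA_SHA="b3ae4966b2567811880db10d9e040a775b99c7d7"  # placeholder; replace with a verified working commit
RUN git clone ${FA_REPO} \
    && cd flash-attention \
    && git checkout ${FA_SHA} \
    && git submodule update --init \
    && FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```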
Motivation
This PR fixes the flash-attention compilation; the steps follow the AMD ROCm support guidance from https://github.com/Dao-AILab/flash-attention.
Technical Details
The pyt_hy_video AMD image failed in two places:
Build: rocmProfileData’s rocpd_python Makefile runs pip install --user ., which is invalid inside the base image’s Python venv (/opt/venv), so make install exited with a pip error.
Runtime: pip resolved transformers 5.x as a dependency of distvae, while diffusers==0.32.2 still imports FLAX_WEIGHTS_NAME from transformers.utils, which is no longer exposed in v5. That caused ImportError when loading diffusers pipelines (e.g. hunyuan_video_usp_example.py under torchrun).
To fix the runtime failure, constrain transformers to >=4.44,<5 on the same install line as diffusers / distvae so the resolver stays on a 4.x release compatible with diffusers 0.32.2.
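For illustration, an install line of the shape described (only transformers, diffusers, and distvae are shown; the actual Dockerfile line may carry additional packages):

```dockerfile
# Constraining transformers in the same resolver pass keeps it on 4.x,
# which still exposes transformers.utils.FLAX_WEIGHTS_NAME for diffusers 0.32.2.
RUN pip install "diffusers==0.32.2" "transformers>=4.44,<5" distvae
```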
To fix the build failure, patch rocpd_python/Makefile after cloning rocmProfileData, replacing pip install --user with pip install before running make install, so the install targets the venv.
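A minimal sketch of that patch step, assuming the Makefile lives at rocmProfileData/rocpd_python/Makefile and contains the literal text pip install --user (the sed expression and surrounding make invocation are illustrative, not copied from the PR):

```dockerfile
# Strip --user so pip installs into the active venv (/opt/venv) instead of erroring out.
RUN cd rocmProfileData \
    && sed -i 's/pip install --user/pip install/g' rocpd_python/Makefile \
    && make install
```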
Flash-attention: switch to the ROCm Triton path (FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + pip install --no-build-isolation .) so the image can be built in headless CI without a GPU-arch-specific wheel build; the previous pinned-SHA / bdist_wheel path is left commented for reference.
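As a quick sanity check once the image is built (a suggestion, not part of this PR's stated test plan; it assumes the ROCm fork keeps the standard flash_attn package name and version attribute):

```sh
# Inside the container: confirm the Triton-backend flash_attn imports cleanly without a GPU-arch wheel.
FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python3 -c "import flash_attn; print(flash_attn.__version__)"
```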
Test Plan
Test Result
Submission Checklist